package com.cat2bug.junit;

import com.cat2bug.junit.annotation.AutoTestScan;
import com.cat2bug.junit.clazz.SpringControllerTestClassFactory;
import com.cat2bug.junit.listener.Cat2BugRunListener;
import com.cat2bug.junit.service.CompileClassResultService;
import com.cat2bug.junit.service.ParameterService;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.runner.RunWith;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.Suite;
import org.junit.runners.model.RunnerBuilder;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RestController;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * JUnit runner that automatically creates test cases for the Spring project under test: it scans for
 * Controller classes and runs one generated proxy test class per Controller as children of this suite.
 */
@RunWith(Suite.class)
public class Cat2BugAutoSpringSuite extends Suite {
	private static final Log log = LogFactory.getLog(Cat2BugAutoSpringSuite.class);
	/**
	 * Generated proxy test classes, cached per unit-test class. Keying by the test class ensures every
	 * suite class using this runner gets proxies built from its own {@link AutoTestScan} settings; a single
	 * shared field would hand the first suite's classes to every later suite in the same JVM.
	 */
	private static final Map<Class<?>, Class<?>[]> testClassesCache = new ConcurrentHashMap<>();
	/** Factory that generates a test class for a given Controller class. */
	private static final SpringControllerTestClassFactory controllerScriptFactory = new SpringControllerTestClassFactory();

	private final Class<?> klass;

	/**
	 * Creates the suite, generating one proxy test class per Controller found in the package configured
	 * by {@link AutoTestScan} on {@code klass}.
	 *
	 * @param klass   the test class this runner was created for
	 * @param builder the runner builder supplied by JUnit
	 * @throws Exception if scanning or building the child runners fails
	 */
	public Cat2BugAutoSpringSuite(Class<?> klass, RunnerBuilder builder) throws Exception {
		super(builder, klass, createTestControllerProxyClasses(klass));
		this.klass = klass;
	}

	/**
	 * Runs the generated tests. The {@link Cat2BugRunListener} is registered before the children run so
	 * that it actually receives the test events; registering it after {@code super.run} meant it could
	 * only observe events fired once the whole suite had already finished.
	 */
	@Override
	public void run(RunNotifier notifier) {
		notifier.addListener(new Cat2BugRunListener(this.klass));
		super.run(notifier);
	}

	/**
	 * Creates the Controller proxy test classes for a unit-test class (cached per test class).
	 *
	 * @param testCaseClass unit-test class
	 * @return the generated proxy classes (Controllers whose generation failed are omitted)
	 * @throws Exception if scanning or parameter registration fails
	 */
	private static Class<?>[] createTestControllerProxyClasses(Class<?> testCaseClass) throws Exception {
		Class<?>[] cached = testClassesCache.get(testCaseClass);
		if (cached != null) {
			return cached;
		}
		AutoTestScan atc = testCaseClass.getAnnotation(AutoTestScan.class);
		// Annotation members are never null, so only the missing-annotation case needs a fallback.
		String scanPackage = atc == null ? "" : atc.packageName();

		Set<Class<?>> controllerClasses = scanControllerClass(scanPackage);
		ParameterService.getInstance().addParameterCreateClass(testCaseClass,
				controllerClasses.stream().map(item -> item.getName() + "Test").toArray(String[]::new));

		Class<?>[] proxies = buildProxyClasses(testCaseClass, controllerClasses);
		// If another thread raced us, keep the first stored result so all callers see the same classes.
		Class<?>[] previous = testClassesCache.putIfAbsent(testCaseClass, proxies);
		return previous != null ? previous : proxies;
	}

	/**
	 * Generates one proxy test class per Controller. A failure for one Controller is logged and recorded
	 * in {@link CompileClassResultService} without aborting generation for the others.
	 *
	 * @param testCaseClass     unit-test class
	 * @param controllerClasses Controller classes to generate tests for
	 * @return the successfully generated proxy classes, in iteration order of {@code controllerClasses}
	 */
	private static Class<?>[] buildProxyClasses(Class<?> testCaseClass, Set<Class<?>> controllerClasses) {
		List<Class<?>> proxyClasses = new ArrayList<>(controllerClasses.size());
		for (Class<?> clazz : controllerClasses) {
			try {
				proxyClasses.add(createProxyClass(testCaseClass, clazz));
				CompileClassResultService.addSuccessClass(clazz);
			} catch (Exception e) {
				log.error("Failed to create test proxy class for " + clazz.getName(), e);
				CompileClassResultService.addFailClass(clazz, e);
			}
		}
		return proxyClasses.toArray(new Class<?>[0]);
	}

	/**
	 * Creates the proxy test class for a Controller class under test.
	 *
	 * @param testCaseClass unit-test class
	 * @param destTestClass Controller class to be tested
	 * @return the generated test class
	 * @throws Exception if the factory fails to generate the class
	 */
	private static Class<?> createProxyClass(Class<?> testCaseClass, Class<?> destTestClass) throws Exception {
		return controllerScriptFactory.createTestClass(testCaseClass, destTestClass);
	}

	/**
	 * Scans the given package for classes annotated with {@code @Controller} or {@code @RestController}.
	 *
	 * @param scanPackage base package to scan ({@code ""} when no package is configured)
	 * @return the loaded Controller classes, in the scanner's discovery order; classes that cannot be
	 *         loaded are logged and skipped
	 */
	private static Set<Class<?>> scanControllerClass(String scanPackage) {
		// false: do not register the default @Component/@Service/... filters, only the ones below.
		ClassPathScanningCandidateComponentProvider provider =
				new ClassPathScanningCandidateComponentProvider(false);
		provider.addIncludeFilter(new AnnotationTypeFilter(Controller.class));
		provider.addIncludeFilter(new AnnotationTypeFilter(RestController.class));

		// LinkedHashSet keeps discovery order; a HashSet of Class (identity hash codes) would make the
		// generated test order vary from run to run.
		Set<Class<?>> clazzs = new LinkedHashSet<>();
		provider.findCandidateComponents(scanPackage).forEach(beanDefinition -> {
			String className = beanDefinition.getBeanClassName();
			try {
				clazzs.add(Class.forName(className));
			} catch (ClassNotFoundException e) {
				log.error("Controller class not found: " + className, e);
			}
		});
		return clazzs;
	}
}
